import torch
import torch.nn as nn
from .base_embedding import BaseEmbedding

class DalleTextEmbedding(BaseEmbedding):
    """Token embedding plus positional embedding for a batch of text tokens.

    Optionally (``mask_embedding=True``), positions marked as padding in the
    ``mask`` passed to ``forward`` are replaced by a learned per-position pad
    embedding instead of the token + position embedding.
    """

    def __init__(self,
                 num_embed,
                 embed_dim,
                 seq_len=256,
                 trainable=True,
                 pos_emb_type='embedding',
                 mask_embedding=False,
        ):
        """
        Args:
            num_embed: number of rows in the token embedding table.
            embed_dim: size D of every embedding vector.
            seq_len: number of positions in the positional (and pad) tables;
                sequences longer than this cannot be embedded by ``forward``.
            trainable: stored on the instance; presumably consumed by
                ``_set_trainable`` from ``BaseEmbedding`` (defined elsewhere).
            pos_emb_type: 'embedding' -> ``nn.Embedding(seq_len, embed_dim)``;
                'parameter' -> zero-initialised ``nn.Parameter`` of shape
                (1, seq_len, embed_dim).
            mask_embedding: if True, create ``pad_emb`` and blend it in at
                positions where ``mask`` is False.
        """
        super().__init__()

        self.num_embed = num_embed
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.trainable = trainable
        self.pos_emb_type = pos_emb_type
        self.mask_embedding = mask_embedding

        assert self.pos_emb_type in ['embedding', 'parameter']
        if not self.mask_embedding:
            print('{} The embedding will not be masked!'.format(self.__class__.__name__))

        self.emb = nn.Embedding(num_embed, embed_dim)
        if self.pos_emb_type == 'embedding':
            self.pos_emb = nn.Embedding(seq_len, embed_dim)
        else:
            self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, embed_dim))

        if self.mask_embedding:
            # One learned pad vector per position, used where mask is False.
            self.pad_emb = nn.Embedding(seq_len, embed_dim)

        self._set_trainable()

    def forward(self, index, mask=None, **kwargs):
        """Embed a batch of token indices.

        Args:
            index: integer tensor of shape B x L. Negative entries (padding)
                are looked up as token 0; the caller's tensor is NOT modified.
            mask: optional B x L bool tensor, False marking padded positions.
                Only used when ``mask_embedding`` is True; if it is None, no
                masking is applied.
            **kwargs: ignored.

        Returns:
            Tensor of shape B x L x D (D = ``embed_dim``).

        Raises:
            RuntimeError: if the token lookup fails, e.g. an index is
                >= ``num_embed`` (the original error is chained as the cause).
        """
        assert index.dim() == 2  # B x L
        # Map negative padding ids to 0 out of place; the previous in-place
        # assignment mutated the tensor owned by the caller.
        index = index.clamp(min=0)
        try:
            emb = self.emb(index)  # B x L x D
        except (IndexError, RuntimeError) as err:
            raise RuntimeError(
                'IndexError: index out of range in self, max index {}, num embed {}'.format(
                    index.max(), self.num_embed)
            ) from err

        length = index.shape[1]
        positions = torch.arange(length, device=index.device).view(1, length)  # 1 x L
        if self.pos_emb_type == 'embedding':
            pos_emb = self.pos_emb(positions)  # 1 x L x D
        else:
            pos_emb = self.pos_emb[:, :length, :]  # 1 x L x D
        emb = emb + pos_emb

        # Without a mask there is nothing to blend; previously this path
        # crashed on ``None.unsqueeze`` because mask defaults to None.
        if self.mask_embedding and mask is not None:
            pad_emb = self.pad_emb(positions)  # 1 x L x D
            keep = mask.unsqueeze(-1).to(emb)  # B x L x 1, 1.0 where real token
            emb = emb * keep + pad_emb * (1 - keep)

        return emb
